import math
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import PIL
import torch
import torch.nn.functional as F

import torch.nn as nn



class ConvInjectedLinear(nn.Module):
    """3x3 convolution augmented with two parallel low-rank 1x1 adapter branches.

    forward(x) = Vidaconv(x)
                 + scale * Vida_up(Vida_down(x))
                 + scale * Vida_up2(Vida_down2(x))

    Both ``Vida_up`` projections are zero-initialised, so a freshly built module
    produces exactly ``Vidaconv(x)``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by every branch.
        bias: whether ``Vidaconv`` carries a bias (the adapter convs never do).
        r: bottleneck width of the first adapter branch.
        r2: bottleneck width of the second adapter branch.
    """

    def __init__(self, in_channels, out_channels, bias=False, r=4, r2 = 64):
        super().__init__()

        def pointwise(src, dst):
            # Bias-free 1x1 / stride-1 projection shared by all adapter layers.
            return nn.Conv2d(src, dst, kernel_size=1, stride=1, bias=False)

        # Registration order is significant: it fixes named_parameters() order
        # and the order in which default initialisers draw from the RNG.
        self.Vidaconv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.Vida_down = pointwise(in_channels, r)
        self.Vida_up = pointwise(r, out_channels)
        self.Vida_down2 = pointwise(in_channels, r2)
        self.Vida_up2 = pointwise(r2, out_channels)
        self.scale = 1.0

        down_std = 1 / r**2
        # Down projections get small random weights; up projections start at
        # zero so the adapters contribute nothing until trained.
        # NOTE(review): the second branch's std is also derived from ``r`` (not
        # ``r2``) -- confirm this is intended.
        for down, up in ((self.Vida_down, self.Vida_up), (self.Vida_down2, self.Vida_up2)):
            nn.init.normal_(down.weight, std=down_std)
            nn.init.zeros_(up.weight)

    def forward(self, input):
        main = self.Vidaconv(input)
        branch_a = self.Vida_up(self.Vida_down(input)) * self.scale
        branch_b = self.Vida_up2(self.Vida_down2(input)) * self.scale
        return main + branch_a + branch_b


class ConvInjectedLinear_group(nn.Module):
    """Grouped 3x3 convolution augmented with two low-rank 1x1 adapter branches.

    forward(x) = Vidaconv(x)
                 + scale * Vida_up(Vida_down(x))
                 + scale * Vida_up2(Vida_down2(x))

    Only ``Vidaconv`` is grouped; the adapter convs are ordinary (ungrouped)
    1x1 convolutions. Both ``Vida_up`` projections are zero-initialised, so a
    freshly built module produces exactly ``Vidaconv(x)``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by every branch.
        bias: whether ``Vidaconv`` carries a bias (the adapter convs never do).
        r: bottleneck width of the first adapter branch.
        r2: bottleneck width of the second adapter branch.
        groups: group count of ``Vidaconv``. Defaults to 4, the previously
            hard-coded value. It must match the conv whose weight is later
            assigned to ``Vidaconv``, and divide both channel counts
            (``nn.Conv2d`` raises otherwise).
    """

    def __init__(self, in_channels, out_channels, bias=False, r=4, r2 = 64, groups=4):
        super().__init__()

        # Construction order is kept fixed: it determines parameter order and
        # the RNG draws made by each layer's default initialiser.
        self.Vidaconv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=groups, bias=bias)
        self.Vida_down = nn.Conv2d(in_channels, r, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.Vida_up = nn.Conv2d(r, out_channels, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.Vida_down2 = nn.Conv2d(in_channels, r2, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.Vida_up2 = nn.Conv2d(r2, out_channels, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.scale = 1.0

        # Down projections get small random weights; up projections start at
        # zero so the adapters contribute nothing until trained.
        nn.init.normal_(self.Vida_down.weight, std=1 / r**2)
        nn.init.zeros_(self.Vida_up.weight)

        # NOTE(review): std is derived from ``r`` (not ``r2``) for the second
        # branch as well -- confirm this is intended.
        nn.init.normal_(self.Vida_down2.weight, std=1 / r**2)
        nn.init.zeros_(self.Vida_up2.weight)

    def forward(self, input):
        return self.Vidaconv(input) + self.Vida_up(self.Vida_down(input)) * self.scale + self.Vida_up2(self.Vida_down2(input)) * self.scale



def inject_trainable_Vida(
    model: nn.Module,
    target_replace_module: List[str] = ["CrossAttention", "Attention"],
    target_name: str = 'conv2',
    r: int = 4,
    r2: int = 16,
):
    """
    Inject ViDA adapters into ``model`` (in place) and return the adapter parameters.

    Every module whose class name is in ``target_replace_module`` is searched for
    a descendant whose qualified name equals ``target_name`` and whose ``stride``
    is ``(1, 1)``. Each match is swapped for a ``ConvInjectedLinear``, or for a
    ``ConvInjectedLinear_group`` when the name is ``'conv_conv'``. The swapped-in
    module shares the original layer's weight (and bias, if any) Parameter objects
    and adds two zero-initialised low-rank adapter branches.

    Args:
        model: network to modify in place.
        target_replace_module: class names of the modules whose descendants are
            searched. The list is only read, never mutated.
        target_name: name, relative to a target module, of the conv to replace.
            Dotted names such as ``"block.conv2"`` replace the nested child.
        r: bottleneck width of the first adapter branch.
        r2: bottleneck width of the second adapter branch.

    Returns:
        ``(require_grad_params, names)``. ``require_grad_params`` holds the
        adapter parameters of every replaced layer, in the order
        up, down, up2, down2, each with ``requires_grad=True``. ``names`` holds
        the name of each replaced layer.

    NOTE(review): the wrapper's ``Vidaconv`` is always 3x3 with padding 1, so a
    swap is only shape-correct when the original conv has that geometry -- confirm
    for new targets. The shared weight/bias keep their original ``requires_grad``;
    freezing them is left to the caller.
    """

    require_grad_params = []
    names = []
    target_stride = (1, 1)
    for _module in model.modules():
        if _module.__class__.__name__ not in target_replace_module:
            continue

        for name, _child_module in _module.named_modules():
            # getattr: a same-named child without ``stride`` (e.g. a wrapper left
            # by an earlier injection) is skipped instead of raising.
            if name != target_name or getattr(_child_module, "stride", None) != target_stride:
                continue

            weight = _child_module.weight
            bias = _child_module.bias
            wrapper_cls = ConvInjectedLinear_group if name == 'conv_conv' else ConvInjectedLinear
            _tmp = wrapper_cls(
                _child_module.in_channels,
                _child_module.out_channels,
                bias is not None,
                r,
                r2,
            )
            # New adapter layers are created on CPU; put them on the replaced
            # layer's device so forward() does not mix devices. Done before the
            # original Parameters are attached, so those are left untouched.
            _tmp.to(device=weight.device)
            _tmp.Vidaconv.weight = weight
            if bias is not None:
                # Share the original bias with the wrapped convolution.
                _tmp.Vidaconv.bias = bias

            # Swap the module inside its direct parent. For a plain name the
            # parent is ``_module`` itself; for a dotted name it is the nested
            # module the prefix points to.
            parent_name, _, child_name = name.rpartition(".")
            parent = _module
            if parent_name:
                for part in parent_name.split("."):
                    parent = getattr(parent, part)
            parent._modules[child_name] = _tmp

            # Collect adapter weights in the order up, down, up2, down2.
            for adapter in (_tmp.Vida_up, _tmp.Vida_down, _tmp.Vida_up2, _tmp.Vida_down2):
                require_grad_params.extend(adapter.parameters())
                adapter.weight.requires_grad = True
            names.append(name)

    return require_grad_params, names


